from collections import defaultdict
import random
from hypersense.sampler.base_sampler import BaseSampler


class StratifiedSampler(BaseSampler):
    """
    Stratified sampler that preserves class distribution in the sampled subset.
    Assumes that the last element in each sample is the class label.

    Per-class sample counts are allocated with the largest-remainder
    (Hamilton) method. The returned subset therefore always contains exactly
    ``sample_size`` rows, and each class receives either the floor or the
    ceiling of its proportional share.

    Future improvements:
    1. Add support for multi-label classification.
    2. Implement Stratified KMeans Bucket Sampling.
    3. Implement Stratified KMeans Sampling.
    4. Allow explicit specification of the "label column index" or column name (instead of defaulting to the last column).
    """

    @staticmethod
    def _allocate_counts(class_sizes, sample_size, total_size):
        """
        Split ``sample_size`` across classes proportionally (largest remainder).

        Args:
            class_sizes (List[int]): Number of rows in each class, in a fixed order.
            sample_size (int): Total rows to allocate; expected to satisfy
                ``0 <= sample_size <= total_size``.
            total_size (int): Sum of ``class_sizes``.

        Returns:
            List[int]: Per-class counts aligned with ``class_sizes``. They sum
            to exactly ``sample_size``, and no count exceeds its class size.
        """
        counts = []
        remainders = []
        for size in class_sizes:
            # Integer divmod keeps each quota exact: no float rounding drift.
            quota, remainder = divmod(sample_size * size, total_size)
            counts.append(quota)
            remainders.append(remainder)

        # Flooring loses some seats in total. Give them back one each to the
        # classes with the largest fractional parts. Only classes with a nonzero
        # remainder can be chosen, so no count exceeds its class size.
        # sorted() is stable even with reverse=True, so ties go to the class
        # whose label appeared first in the dataset.
        shortfall = sample_size - sum(counts)
        order = sorted(range(len(class_sizes)), key=lambda i: remainders[i], reverse=True)
        for i in order[:shortfall]:
            counts[i] += 1
        return counts

    def sample(self):
        """
        Perform stratified sampling on the dataset.

        Rows are grouped by their last element (the class label). Each group
        contributes a number of rows proportional to its share of the dataset.
        Rows within a group are drawn without replacement by a
        ``random.Random`` seeded with ``self.seed``, so the result is
        reproducible for a fixed seed and dataset order.

        Returns:
            List[Any]: The sampled subset of the dataset, containing exactly
            ``self.sample_size`` rows. Rows are grouped by class, in order of
            each label's first appearance in the dataset.

        Raises:
            ValueError: If ``self.sample_size`` is negative or exceeds the
                dataset size.
        """
        # Validate before doing any grouping work.
        total_size = len(self.dataset)
        if self.sample_size < 0:
            raise ValueError(f"Sample size ({self.sample_size}) must be non-negative.")
        if self.sample_size > total_size:
            raise ValueError(f"Sample size ({self.sample_size}) exceeds dataset size ({total_size}).")

        rng = random.Random(self.seed)

        # Group samples by class label (assumes last element is the label).
        label_to_samples = defaultdict(list)
        for row in self.dataset:
            label_to_samples[row[-1]].append(row)

        groups = list(label_to_samples.values())
        counts = self._allocate_counts(
            [len(group) for group in groups], self.sample_size, total_size
        )

        sampled_subset = []
        for group, count in zip(groups, counts):
            sampled_subset.extend(rng.sample(group, count))
        return sampled_subset
